{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad4ae47e-6f18-414d-b04f-61f0dee4f7f7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!pip install transformers tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ae2b727-1cb6-4721-b876-53686e81b199",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!HF_ENDPOINT=https://hf-mirror.com huggingface-cli download bert-base-chinese"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ace61491",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import BertTokenizer, BertModel\n",
    "from tqdm import tqdm\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3f36170",
   "metadata": {},
   "source": [
    "# 加载文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d741c1a5-db2a-4e5d-aa57-aa1fab744385",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "with open('文本.txt', 'r', encoding='utf-8') as file:\n",
    "  sentences = file.readlines()\n",
    "print('文本条数: ', len(sentences))\n",
    "print('预览第一条: ', sentences[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5530731",
   "metadata": {},
   "source": [
    "# 训练词向量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8b6de19",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 加载预训练模型和tokenizer\n",
    "# 模型名字直接写入bert-base-chinese这个简化模型名就可以了，https://huggingface.co/google-bert/bert-base-chinese\n",
    "# 如果无法用梯子的话，可以本地下载：huggingface-cli download --resume-download bert-base-chinese\n",
    "model_name = \"bert-base-chinese\"\n",
    "# model_name = \"hfl/chinese-bert-wwm\"\n",
    "\n",
    "# 也可以试试使用哈工大的模型，model_name = \"hfl/chinese-bert-wwm\"\n",
    "# 注意提前需要下载huggingface-cli download --resume-download hfl/chinese-bert-wwm\n",
    "\n",
    "# 加载模型\n",
    "# 会从huggingface中下载模型\n",
    "# 源码：class PreTrainedModel(nn.Module....)\n",
    "# 所以，这里创建的既是PreTrainedModel类的实例，也是torch.nn.Module的实例\n",
    "# 对于警告Some weights of the model checkpoint，对与我们的任务，可以不用在意\n",
    "# 相关讨论：https://blog.csdn.net/PolarisRisingWar/article/details/123974645   https://huggingface.co/google-bert/bert-base-uncased/discussions/4\n",
    "model = BertModel.from_pretrained(model_name)\n",
    "\n",
    "# 加载tokenizer\n",
    "# 使用Tokenizer，就是为了将输入的句子加工为bert模型可以处理的格式\n",
    "tokenizer = BertTokenizer.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0988189",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 将模型放置在GPU上\n",
    "# torch.cuda.is_available()，检测cuda是否可用\n",
    "# torch.device()设置张量运算在哪个设备上进行\n",
    "# device = torch.device(\"cpu\")，表示在CPU进行\n",
    "# device = torch.device(\"cuda\")，表示在GPU进行\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# 把模型放到cpu或gpu\n",
    "model.to(device)\n",
    "# 将模型设置为评估模式，https://blog.csdn.net/weixin_45275599/article/details/131524189\n",
    "model.eval() \n",
    "\n",
    "# 切分数据\n",
    "batch_size = 16  # 批大小\n",
    "data_loader = DataLoader(sentences, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b7579ac-30ae-4d76-baa8-8401d9a21748",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# ---- 文本转向量 ----\n",
    "# 生成的向量存放在这里\n",
    "cls_embeddings = []\n",
    "\n",
    "# 使用tqdm显示处理进度\n",
    "# tqdm b站教程：https://www.bilibili.com/video/BV1ZG411M7Ge/?spm_id_from=333.337.search-card.all.click&vd_source=eace37b0970f8d3d597d32f39dec89d8\n",
    "for batch_sentences in tqdm(data_loader):\n",
    "    # tokenizer官方文档：https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__\n",
    "    # truncation=True，对输入句子进行截断，这里确保最大长度不超过512个字\n",
    "    # max_length：不设置的话，默认会截断到该模型可接受的最大长度\n",
    "    # padding=True 或 padding='longest': 将所有句子填充到批次中最长句子的长度\n",
    "    # padding=\"max_length\": 将所有句子填充到由 max_length 参数指定的长度\n",
    "    inputs = tokenizer(batch_sentences, padding=True, truncation=True, return_tensors=\"pt\", max_length=512)\n",
    "    # print(123, inputs.input_ids[0], tokenizer.decode(inputs.input_ids[0]))\n",
    "    \n",
    "    # 把编码好的数据，也放在device上，It is necessary to have both the model, and the data on the same device, either CPU or GPU\n",
    "    # https://huggingface.co/docs/transformers/v4.39.2/en/main_classes/tokenizer#transformers.BatchEncoding.to\n",
    "    # https://stackoverflow.com/questions/63061779/pytorch-when-do-i-need-to-use-todevice-on-a-model-or-tensor\n",
    "    inputs.to(device)\n",
    "\n",
    "    # 设置不要计算梯度\n",
    "    # 一般来说，如果我们只是用模型进行“预测”，而不涉及对模型进行更新时，就不需要计算梯度，以此来节约内存，增加运算效率\n",
    "    # with上下文中，对model的调用将遵循torch.no_grad()，即不会计算梯度\n",
    "    with torch.no_grad():\n",
    "        outputs = model(**inputs)\n",
    "\n",
    "    # 把这一批词向量存入cls_embeddings容器中\n",
    "    # tensor.cpu() 将张量移动到 CPU\n",
    "    # tensor.numpy() 将 CPU 上的张量转换为 NumPy 数组\n",
    "    cls_embeddings.append(outputs.last_hidden_state[:, 0].cpu().numpy()) # 只取CLS对应的向量\n",
    "\n",
    "# 合并句子向量\n",
    "cls_embeddings = np.vstack(cls_embeddings)\n",
    "print('最终生成的词向量', type(cls_embeddings), cls_embeddings.shape)\n",
    "\n",
    "# ---- 保存词嵌入向量 ----\n",
    "# 保存句子向量到npy文件\n",
    "# 官方文档：https://numpy.org/doc/stable/reference/generated/numpy.save.html\n",
    "output_file = \"emb.npy\"\n",
    "np.save(output_file, cls_embeddings)\n",
    "print(\"词向量存储于: \", output_file)\n",
    "\n",
    "embeddings = np.load(output_file)\n",
    "print(\"加载回来，验证一下：\", type(embeddings), embeddings.shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
